from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np  
from sklearn.metrics import auc, roc_curve, roc_auc_score, recall_score, precision_score, accuracy_score
import json
import os
from peft import PeftModel

# Model names (the last "/"-separated component of ``args.model_name``) that
# ``load_model`` knows how to load.
support_model = ["llama-7b", "contam-1.4b"]

def load_model(args):
    """Load the causal LM and tokenizer named by ``args.model_name``.

    Args:
        args: Object exposing ``model_name`` (a path or hub id; only its last
            "/"-separated component is matched against ``support_model``),
            ``fine_tuning`` (truthy to wrap the model with a PEFT adapter) and
            ``lora_name`` (adapter location, read only when ``fine_tuning``
            is truthy).

    Returns:
        A ``(model, tokenizer)`` tuple. Both entries are ``None`` when the
        model name is not in ``support_model`` and no fine-tuning is requested.

    Raises:
        ValueError: If ``args.fine_tuning`` is set but the base model is not
            supported, so there is nothing to attach the adapter to.
    """
    model_name = args.model_name.split("/")[-1]
    print("model_name: ", model_name)
    if model_name in support_model:
        model = AutoModelForCausalLM.from_pretrained(args.model_name, return_dict=True, device_map='auto')
        # Put the model in evaluation mode.
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    else:
        # TODO: loading for models outside ``support_model`` is not implemented.
        model = None
        tokenizer = None

    if args.fine_tuning:
        if model is None:
            # Fail early with a clear message instead of handing ``None`` to
            # PEFT, which would only fail later with an unrelated-looking error.
            raise ValueError(
                "Cannot apply fine-tuned adapter: base model '%s' is not supported (supported: %s)"
                % (model_name, ", ".join(support_model))
            )
        model = PeftModel.from_pretrained(model, args.lora_name)
    return model, tokenizer

def load_data(args):
    """Load the evaluation examples named by ``args.dataset_name``.

    The loader is chosen from the last "/"-separated component of
    ``args.dataset_name``:

    * contains ``"WikiMIA"``: fetched with ``load_dataset`` using the split
      ``WikiMIA_length{args.length}`` and converted to a list of examples.
    * ends with ``.json``: the whole file is parsed as one JSON document.
    * ends with ``.txt`` or ``.jsonl``: the raw lines are returned as strings
      (trailing newlines kept, no per-line parsing).

    Args:
        args: Object exposing ``dataset_name`` and, for WikiMIA, ``length``.

    Returns:
        The loaded data: a list of examples, the parsed JSON value, or a list
        of raw lines, depending on the branch taken.

    Raises:
        ValueError: If ``args.length`` is not one of 32, 64, 128, 256 for a
            WikiMIA dataset, or if the dataset name matches no known format.
    """
    dataset_name = args.dataset_name.split("/")[-1]
    if "WikiMIA" in dataset_name:
        # Explicit check instead of ``assert`` so the validation still runs
        # under ``python -O``. ``None`` is rejected by the membership test.
        if args.length not in (32, 64, 128, 256):
            raise ValueError("args.length must be one of [32, 64, 128, 256]")
        # Load the requested length split from the Hugging Face hub.
        dataset = load_dataset(args.dataset_name, split=f"WikiMIA_length{args.length}")
        data = convert_huggingface_data_to_list_dic(dataset)
    elif dataset_name.endswith(".json"):
        # A JSON-serialized list of examples.
        print("loading from json...")
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's default encoding.
        with open(args.dataset_name, "r", encoding="utf-8") as f:
            data = json.load(f)
    elif dataset_name.endswith((".txt", ".jsonl")):
        # A dataset where each example is on its own line.
        with open(args.dataset_name, "r", encoding="utf-8") as f:
            data = f.readlines()
    else:
        raise ValueError("Invalid dataset name")
    return data
        
def convert_huggingface_data_to_list_dic(dataset):
    """Return the examples of an indexable, sized dataset as a plain list.

    Each position ``0 .. len(dataset) - 1`` is read with ``dataset[position]``
    and collected in order.
    """
    return [dataset[position] for position in range(len(dataset))]

def sweep(prediction, labels):
    """
    Build the ROC curve for the negated scores and summarize it.

    Returns a tuple ``(fpr, tpr, auc, acc)`` where ``fpr`` and ``tpr`` are the
    arrays produced by ``roc_curve(labels, -prediction)``, ``auc`` is the area
    under that curve, and ``acc`` is the largest value over the curve of
    ``1 - (fpr + (1 - tpr)) / 2``.
    """
    false_pos_rate, true_pos_rate, _unused_thresholds = roc_curve(labels, -prediction)
    # Average of the true-positive and true-negative rates at each ROC point.
    per_point_acc = 1-(false_pos_rate+(1-true_pos_rate))/2
    best_acc = np.max(per_point_acc)
    area = auc(false_pos_rate, true_pos_rate)
    return false_pos_rate, true_pos_rate, area, best_acc

def do_plot(prediction, answers, sweep_fn=sweep, metric='auc', legend="", output_dir=None, val_threshold=None):
    """
    Score one attack, print its metrics, and add its ROC curve to the current plot.

    All ROC and threshold computations run on the negated values of
    ``prediction``.

    Args:
        prediction: Sequence of per-example scores.
        answers: Sequence of per-example labels; converted to a boolean array.
        sweep_fn: Callable ``(scores_array, labels_array)`` returning
            ``(fpr, tpr, auc, acc)``. Its ``fpr``/``tpr`` are used only for
            the TPR at FPR < 5%.
        metric: ``'auc'`` or ``'acc'`` appends that value to the plot label;
            any other value appends nothing.
        legend: Name of the attack, used in the printed line and plot label.
        output_dir: Unused here; kept so existing callers keep working.
        val_threshold: If not ``None``, this threshold (on the negated scores)
            replaces the accuracy-maximizing one found on this data.

    Returns:
        Tuple ``(accuracy, best_threshold, legend, auc, acc, low)`` where
        ``accuracy`` is measured at ``best_threshold`` and ``low`` is the TPR
        at the last ROC point with FPR below 0.05.
    """
    # Convert once; the same label array is reused for every metric below.
    labels = np.array(answers, dtype=bool)
    fpr, tpr, auc_value, acc = sweep_fn(np.array(prediction), labels)

    low = tpr[np.where(fpr<.05)[0][-1]] # TPR@5%FPR

    # Choose the threshold that gives the best accuracy.
    scores = np.array([-x for x in prediction])
    fpr, tpr, thresholds = roc_curve(labels, scores)

    accuracies = [
        accuracy_score(labels, (scores >= candidate).astype(int))
        for candidate in thresholds
    ]

    best_threshold_index = np.argmax(accuracies)
    best_threshold = thresholds[best_threshold_index]

    # A threshold selected on a validation set overrides the one found here.
    if val_threshold is not None:
        best_threshold = val_threshold

    # Compute the predictions at the chosen threshold.
    best_predictions = (scores >= best_threshold).astype(int)

    # Compute accuracy, recall and precision at that threshold.
    accuracy = accuracy_score(labels, best_predictions)
    recall = recall_score(labels, best_predictions)
    precision = precision_score(labels, best_predictions)

    print('Attack %s   AUC %.4f, Accuracy %.4f, TPR@5%%FPR of %.4f\n'%(legend, auc_value,acc, low))
    print('accuracy: %.4f, Recall: %.4f, Precision: %.4f\n' % (accuracy, recall, precision))

    metric_text = ''
    if metric == 'auc':
        metric_text = 'auc=%.3f'%auc_value
    elif metric == 'acc':
        metric_text = 'acc=%.3f'%acc

    plt.plot(fpr, tpr, label=legend+metric_text)
    return accuracy,best_threshold, legend, auc_value,acc, low

def fig_fpr_tpr(all_output, output_dir, val_threshold=None):
    """
    Evaluate every metric in ``all_output`` and write ``auc.txt`` and ``auc.png``.

    Args:
        all_output: Iterable of examples, each with a ``"label"`` entry and a
            ``"pred"`` mapping from metric name to score. Metrics whose name
            contains ``"raw"`` but not ``"clf"`` are skipped.
        output_dir: Directory for the outputs; created if missing.
        val_threshold: Optional sequence of thresholds, one per evaluated
            metric in the order the metrics are first seen. When given, each
            replaces the threshold ``do_plot`` would otherwise pick.

    Returns:
        List of the thresholds returned by ``do_plot``, one per evaluated
        metric, in the same order.
    """
    print("output_dir", output_dir)
    os.makedirs(output_dir, exist_ok=True)
    answers = []
    metric2predictions = defaultdict(list)
    for ex in all_output:
        answers.append(ex["label"])
        for metric, score in ex["pred"].items():
            if ("raw" in metric) and ("clf" not in metric):
                continue
            metric2predictions[metric].append(score)

    threshold = []
    fig = plt.figure(figsize=(4,3))
    try:
        with open(f"{output_dir}/auc.txt", "w") as f:
            for index, (metric, predictions) in enumerate(metric2predictions.items()):
                threshold_select = val_threshold[index] if val_threshold is not None else None
                accuracy, best_threshold, legend, auc, acc, low = do_plot(predictions, answers, legend=metric, metric='auc', output_dir=output_dir, val_threshold=threshold_select)
                f.write('%s   AUC %.4f, Accuracy %.4f, TPR@5%%FPR of %.4f, threshold %.4f, eval_acc %.4f\n'%(legend, auc, acc, low, best_threshold, accuracy))
                threshold.append(best_threshold)

        # Log-log axes for the ROC curves.
        plt.semilogx()
        plt.semilogy()
        plt.xlim(1e-5,1)
        plt.ylim(1e-5,1)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        # Diagonal reference line.
        plt.plot([0, 1], [0, 1], ls='--', color='gray')
        plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
        plt.legend(fontsize=8)
        plt.savefig(f"{output_dir}/auc.png")
    finally:
        # pyplot keeps every figure alive until it is closed; release this one
        # so repeated calls do not accumulate open figures.
        plt.close(fig)
    return threshold

def save_prediction_to_file(predictions, output_dir):
    """
    Write all per-metric scores and labels to ``<output_dir>/predictions.json``.

    Args:
        predictions: Iterable of examples, each with a ``"label"`` entry and a
            ``"pred"`` mapping from metric name to score.
        output_dir: Directory for the output file; created if missing.

    The JSON document has the form
    ``{"prediction": {metric: [score, ...]}, "answer": [label, ...]}``,
    with scores and labels in the order the examples were given.
    """
    predictions_total = defaultdict(list)
    answers = []
    for ex in predictions:
        answers.append(ex["label"])
        for metric_name, score in ex["pred"].items():
            predictions_total[metric_name].append(score)
    payload = {"prediction": predictions_total, "answer": answers}
    print("save_output_dir", output_dir)
    # Create the directory first so saving works on a fresh output path.
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/predictions.json", "w") as f:
        json.dump(payload, f)